{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torchvision version: 0.8.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "print('torchvision version:', torchvision.__version__) # Needs at least >= 0.8.0 to do cropping on tensors\n",
    "\n",
    "from vision_transformer import vit_small\n",
    "from vision_transformer4k import vit4k_xs\n",
    "from main_dino4k import DataAugmentationDINO\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ViT-16 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num Parameters: 21665664\n",
      "1. Input Shape: torch.Size([4, 3, 256, 256])\n",
      "2. Output Shape: torch.Size([4, 384])\n"
     ]
    }
   ],
   "source": [
    "model = vit_small()\n",
    "print(\"Num Parameters:\", count_parameters(model))\n",
    "\n",
    "x = torch.randn(4, 3, 256, 256)\n",
    "print(\"1. Input Shape:\", x.shape)\n",
    "out = model(x)\n",
    "print(\"2. Output Shape:\", out.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ViT-256 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# of Patches: 196\n",
      "Num Parameters: 2793792\n",
      "1. For a 4K x 4K image, torch.load in 256-len sequence of 384-dim embeddings: torch.Size([256, 384])\n",
      "2. Reshape this sequence to be a 2D image grid (B NC W H): torch.Size([1, 384, 16, 16])\n",
      "3. Applying 2D cropping (B NC W H): torch.Size([1, 384, 14, 14])\n",
      "4. Out: torch.Size([1, 192])\n"
     ]
    }
   ],
   "source": [
    "model = vit4k_xs()\n",
    "print(\"Num Parameters:\", count_parameters(model))\n",
    "\n",
    "t_tensorcrop = transforms.Compose([\n",
    "    transforms.RandomCrop(14), # 14 x 14 for \"global\" crop, 6 x 6 for \"local\" crop\n",
    "])\n",
    "\n",
    "# [14 x 14] crop in a 16 x 16 grid would retain the same relative information as [224 x 224] in a 256 x 256 img\n",
    "assert 224/256 == 14/16 \n",
    "\n",
    "# [6 x 6] crop in a 16 x 16 grid would retain the same relative information as [96 x 96] in a 256 x 256 img\n",
    "assert 96/256 == 6/16 \n",
    "\n",
    "x_bag = torch.randn(256, 384)\n",
    "print('1. For a 4K x 4K image, torch.load in 256-len sequence of 384-dim embeddings:', x_bag.shape)\n",
    "x_bag = x_bag.unsqueeze(dim=0).unfold(1, 16, 16).transpose(1,2)\n",
    "print('2. Reshape this sequence to be a 2D image grid (B NC W H):', x_bag.shape)\n",
    "x_bag = t_tensorcrop(x_bag)\n",
    "print('3. Applying 2D cropping (B NC W H):', x_bag.shape)\n",
    "_ = model(x_bag)\n",
    "print('4. Out:', _.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Crop 1: torch.Size([384, 14, 14])\n",
      "Crop 2: torch.Size([384, 14, 14])\n",
      "Crop 3: torch.Size([384, 6, 6])\n",
      "Crop 4: torch.Size([384, 6, 6])\n",
      "Crop 5: torch.Size([384, 6, 6])\n",
      "Crop 6: torch.Size([384, 6, 6])\n",
      "Crop 7: torch.Size([384, 6, 6])\n",
      "Crop 8: torch.Size([384, 6, 6])\n",
      "Crop 9: torch.Size([384, 6, 6])\n",
      "Crop 10: torch.Size([384, 6, 6])\n"
     ]
    }
   ],
   "source": [
    "t_dino = DataAugmentationDINO(8)\n",
    "\n",
    "x_bag = torch.randn(256, 384)\n",
    "x_crops = t_dino(x_bag)\n",
    "for idx, crop in enumerate(x_crops):\n",
    "    print('Crop %d:' % (idx+1), crop.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
